技術問答
技術文章
iT 徵才
Tag
聊天室
2024 鐵人賽
登入/註冊
問答
文章
Tag
邦友
鐵人賽
搜尋
2021 iThome 鐵人賽
DAY
12
0
AI & Data
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型
系列 第
12
篇
Day-11 Backpropagation 介紹
13th鐵人賽
CrazyFire
2021-09-26 13:53:36
12246 瀏覽
分享至
我們前面提到過深度學習就是模仿神經網路建構一個龐大的訓練模型,來達到特徵的選取(調整 weight 的數值來達到決定輸入特徵的權重),那我們看過 Gradient Descent 的數值更新狀況概念很簡單,但實際上我們可以想像當結構變複雜之時,我們可以預期 Gradient Descent 的計算將會變得太過複雜
Baskpropagation(反向傳遞法),就是希望讓 neural network 的 training 變得更加有效率
回顧一下 Gradient Descent
Network parameters
先選擇一個初始的參數
,然後計算這個
對於我們的 loss function 的 Gradient
,也就是計算每一個 network 裡面的參數對於
的偏微分
那我們就會拿到 Gradient,這個 Gradient 會是一個 Vector,就可以利用 Vector 來更新我們的參數
那我們會重複這個流程直到我們的期望次數,所以可以發現在一般的 Logistic Regression 或是 Linear Regression 在這邊的操作是沒太多區別的,唯一的問題是 Neural network 的參數非常的多,我們的 Gradient Vector 會變得非常巨大,所以如何有效地去計算這個 Vector,就是 Backpropagation 在做的事情
所以 Backpropagation 並不是一個全新的方法,他說白了就是 Gradient Descent,只是它是一個更有效率的演算法,目的在於更有效率地去取得 Gradient Vector,這也是為什麼之後提到的 PyTorch Gradient Calculation 會交給 Backpropagation 做計算
About Backpropagation
我們提到過 Backpropagation 可以想成一個更有效率的 Gradient Descent 了,那 Backpropagation 有沒有特別需要注意的部分呢?
對於 Backpropagation 最重要的的觀念就是 Chain Rule(連鎖律)
Chain Rule
Chain Rule 連鎖律其實就是在強調數值之間的關係,那這邊為甚麼會這麼重要是因為回顧一下神經網路傳遞的方式,他們是一層一層的往下傳遞,因此就最終結果而言,其實是受到初始參數的影響一路往下層層變化的,那這些參數之間對於結果的關係是什麼?其實就會受到連鎖律的影響,因此基本的連鎖律概念我們在這裡簡單的幫大家 Summarize 一下
Case 1:
的話,如果 x 受到影響,會影響到 y ,進而影響到 z,也就是
所以如果我們今天要計算
,可以先把它轉換成
Case 2:
也就是說
,還有
,s 透過了兩個路徑去影響到了 z
所以如果我們今天要計算
,可以先把它轉換成
我們已經回顧了基本的 Chain Rule 在微分時會需要注意的部分,讓我們回到 Nueral Network
Basic Nueral Network
我們回到基本的訓練過程去做思考,今天我們的 Nueral Network 在做訓練的過程是怎麼做訓練的?就是我們傳遞了一筆資料,經過神經網路的計算之後,會得到一個答案,那這個答案可能跟我們的預期答案有所落差,因此我們就可以利用這個落差的總和得到我們的 total loss
所以這邊的
就代表著
跟
之間的落差
那如果我們對 loss 和某一個 w 去做偏微分,我們可以發現就等於我把每個參數的 loss 對特定參數 w 的微分加總,就是 loss 對指定的 w 做偏微分了,因此我們之後就可以不用考慮去計算
,而改思考對某一筆 data 的
就可以了
那我們從一個簡單的 Neural network 來看看,假設我們有一個 network 長下面這樣
那我們從某一個 neuron 來看看
那我們今天要算
要怎麼算,依照 Chain Rule 我們可以拆成兩項,也就是
那計算
其實是非常簡單的,我們稱為 Forward pass,那計算
我們則稱為 Backward pass,那為啥要叫 forward 跟 backward 我們等等就知道了
Forward pass
先來看看怎麼計算 Forward pass,我們前面有說我們的
了,所以如果我們希望計算
,其實就是
,
,其實就是
所以我們可以發現一個規律,當我們想找
,事實上就是去看那個 w 前面接的參數,也就是這個神經元的
輸入
因此如果我們希望找到所有的
,就必須先就算
正向
的參數,也就是我們 input 參數進入之後,一路往下到輸出的所有一層一層傳遞的參數,這也是為甚麼我們稱其為 forward pass,因為就是我們一般求輸出的
正向
運算
那這邊也是為甚麼我們說找
是非常簡單的,因為根本就是輸入參數
Backward pass
那如果我們已經知道 Forward pass 就是順向/正向運算,那 Backward pass 顧名思義應該就是反向運算了,但是要怎麼做呢?
我們現在要算
,我們知道 z 好取得,但是 C 就是要繼續往下看一路運算到最後結果,這是非常複雜的,那怎麼辦呢?那我們試著再用 Chain rule 拆解看看這一項
from:
ML Lecture 7: Backpropagation
我們先假設接在 Z 後的 activation function(我們之後再解釋 QQ) 是 sigmoid function
,然後輸出了一個結果
,那我們先不管後面的部分,我們在多了一個變數
之後,就可以利用 Chain rule 再把式子拆分成
那我們先來看
是什麼,我們已經知道
,所以
其實就是
,也就是 sigmoid function 的微分
那
應該長怎樣呢? 應該長
那我們看上圖可以發現
,
其實就是後面的
,
,那
,
呢?怎麼感覺又繞回來一圈了?
我們先整理一下現在
會長怎樣?
,換句話說我們其實只差最後一個步驟了,也就是我們只差知道
,
整個問題就結束了,但是怎麼解?我們換個方向想
如果我們從後面往前推,也就是我們把目標先放在答案那邊,從 output layer 往前推
我們可以得到
,
,我們會發現因為
,
都是已知了,因為我們正向運算一定會算出一個答案,我們的
,
就可以利用 Cost function 來決定(例如 MSE),然後
,
也可以運算了
那如果現在不是在 output layer 呢?其實就是一直往下推一路到 output layer 就可以了,因為只有在 output layer ,我們才有辦法把
這種部份算出來
所以概念上我們就是完全倒過來,從結果一路回推所有的
每日小結
Backpropagation 可以說是深度學習裡面最重要的觀念了,神經網路的構造複雜,本來就很難去計算和更新參數,因此普通的 Gradient Descent 會遇到很多計算上的困難, Backpropagation 則是利用 Chain Rule 的方式,將計算複雜度大大的下降,並利用一次 Forward pass 加一次 Backward pass 來達到快速更新參數計算參數的方式
本日課程大量參考
李弘毅老師的開放式課程
,這份教學非常非常好理解 Backpropagation,因此上面看不懂的部分都可以再去看看,筆者當初在學習的過程中,也深受此系列幫助
到這裡我們已經完成了基本的觀念架設了,明天我們就可以開始介紹 PyTorch Framework 了~
留言
追蹤
檢舉
上一篇
Day-10 深度學習的介紹
下一篇
Day-12 Pytorch 介紹
系列文
Deep Learning 從零開始到放棄的 30 天 PyTorch 數字辨識模型
共
31
篇
目錄
RSS系列文
訂閱系列文
27
人訂閱
27
Day-26 手把手的手寫辨識模型 0x1:資料集整理
28
Day-27 手把手的手寫面是模型 0x2:資料訓練和結果輸出
29
Day-28 手把手的手寫辨識模型 0x3:CNN is the end?模型大哉問
30
Day-29 Pytorch 還可以更輕鬆更簡單!Pytorch Lightning
31
Day-30 不完美收工
完整目錄
直播研討會
{{ item.subject }}
{{ item.channelVendor }}
{{ item.webinarstarted }}
|
{{ formatDate(item.duration) }}
直播中
立即報名
尚未有邦友留言
立即登入留言
iThome鐵人賽
參賽組數
1064
組
團體組數
40
組
累計文章數
22195
篇
完賽人數
600
人
看影片追技術
看更多
{{ item.subject }}
{{ item.channelVendor }}
|
{{ formatDate(item.duration) }}
直播中
熱門tag
看更多
15th鐵人賽
16th鐵人賽
13th鐵人賽
14th鐵人賽
12th鐵人賽
11th鐵人賽
鐵人賽
2019鐵人賽
javascript
2018鐵人賽
python
2017鐵人賽
windows
php
c#
windows server
linux
css
react
vue.js
熱門問題
請問內網IP如何轉外網IP?
如何寫公式才能利用excel 觸發一個數據時傳送一個訊息給 自已的line呢?有沒有可以用其它方式,來取代line notify 的方法,因為line 開始收費
新手學習編程,哪種編程語言好?
Windows7升級Windows10後網路功能異常
python爬蟲 動態生成網頁104人力銀行
區域網路問題提問
vmware 虛擬機(windows)裡顯示使用容量與實際檔案容量不符合
防火牆與DNS請教
FORTI 防火牆使用 RADIUS 認證問題請教
2台 Hyper-V 2008 R2 叢集主機(硬體規格相同), 如何加入一台新機? 謝謝.
熱門回答
請問內網IP如何轉外網IP?
防火牆與DNS請教
這樣的物件設計好嗎?
新手學習編程,哪種編程語言好?
區域網路問題提問
熱門文章
每日一篇學習筆記 直到我做完專題 :( [Day33]
每日一篇學習筆記 直到我做完專題 :( [Day34]
每日一篇學習筆記 直到我做完專題 :( [Day35]
每日一篇學習筆記 直到我做完專題 :( [Day36]
EPS 到 JPG 的高效轉檔策略:Python 實作範例
IT邦幫忙
×
標記使用者
輸入對方的帳號或暱稱
Loading
找不到結果。
標記
{{ result.label }}
{{ result.account }}